{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple PyTorch MNIST experiment\n",
    "---\n",
    "\n",
    "<font color='red'> <h3>Tested with PyTorch 1.3.1</h3></font>\n",
    "<font color='red'> <h3>Tested with torchvision 0.4.2</h3></font>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def wrapper():\n",
    "    \"\"\"Train a small CNN on MNIST for one epoch, evaluate it and save the weights.\n",
    "\n",
    "    All imports and helpers are local to this function. The training loss is\n",
    "    logged through a SummaryWriter pointed at tensorboard.logdir(), and the\n",
    "    trained state_dict is saved in that directory as mnist_cnn.pt.\n",
    "    \"\"\"\n",
    "    import torch\n",
    "    import torch.nn as nn\n",
    "    import torch.nn.functional as F\n",
    "    import torch.optim as optim\n",
    "    from torchvision import datasets, transforms\n",
    "    from torch.utils.tensorboard import SummaryWriter\n",
    "    import os\n",
    "    from hops import tensorboard\n",
    "    from hops import hdfs\n",
    "\n",
    "\n",
    "    class Net(nn.Module):\n",
    "        \"\"\"CNN with two conv layers and two linear layers producing 10 log-probabilities.\"\"\"\n",
    "\n",
    "        def __init__(self):\n",
    "            super(Net, self).__init__()\n",
    "            self.conv1 = nn.Conv2d(1, 20, 5, 1)\n",
    "            self.conv2 = nn.Conv2d(20, 50, 5, 1)\n",
    "            # 4*4*50 is the flattened conv output for 28x28 inputs (28 -> 24 -> 12 -> 8 -> 4)\n",
    "            self.fc1 = nn.Linear(4*4*50, 500)\n",
    "            self.fc2 = nn.Linear(500, 10)\n",
    "\n",
    "        def forward(self, x):\n",
    "            \"\"\"Return log_softmax scores over dim 1 for a batch of images.\"\"\"\n",
    "            x = F.relu(self.conv1(x))\n",
    "            x = F.max_pool2d(x, 2, 2)\n",
    "            x = F.relu(self.conv2(x))\n",
    "            x = F.max_pool2d(x, 2, 2)\n",
    "            x = x.view(-1, 4*4*50)\n",
    "            x = F.relu(self.fc1(x))\n",
    "            x = self.fc2(x)\n",
    "            return F.log_softmax(x, dim=1)\n",
    "\n",
    "    def train(model, device, train_loader, optimizer, writer):\n",
    "        \"\"\"Run one pass over train_loader; print and log the loss every 10 batches.\"\"\"\n",
    "        model.train()\n",
    "        for batch_idx, (data, target) in enumerate(train_loader):\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            output = model(data)\n",
    "            loss = F.nll_loss(output, target)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            if batch_idx % 10 == 0:\n",
    "                # Read the scalar once and reuse it for both the console and TensorBoard\n",
    "                batch_loss = loss.item()\n",
    "                print('[{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n",
    "                    batch_idx * len(data), len(train_loader.dataset),\n",
    "                    100. * batch_idx / len(train_loader), batch_loss))\n",
    "                writer.add_scalar('Loss/train', batch_loss, batch_idx)\n",
    "\n",
    "    def test(model, device, test_loader):\n",
    "        \"\"\"Print the average NLL loss and the accuracy of model over test_loader.\"\"\"\n",
    "        model.eval()\n",
    "        test_loss = 0\n",
    "        correct = 0\n",
    "        with torch.no_grad():\n",
    "            for data, target in test_loader:\n",
    "                data, target = data.to(device), target.to(device)\n",
    "                output = model(data)\n",
    "                test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss\n",
    "                pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability\n",
    "                correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "\n",
    "        test_loss /= len(test_loader.dataset)\n",
    "\n",
    "        print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n",
    "            test_loss, correct, len(test_loader.dataset),\n",
    "            100. * correct / len(test_loader.dataset)))\n",
    "\n",
    "\n",
    "    # Training settings\n",
    "    use_cuda = torch.cuda.is_available()\n",
    "\n",
    "    torch.manual_seed(1)\n",
    "\n",
    "    device = torch.device(\"cuda\" if use_cuda else \"cpu\")\n",
    "\n",
    "    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}\n",
    "\n",
    "    # The same working directory may be used multiple times if running multiple experiments\n",
    "    # Make sure we only download the dataset once\n",
    "    data_root = os.getcwd()\n",
    "    if not os.path.exists(os.path.join(data_root, 'MNIST')):\n",
    "        hdfs.log('----Downloading dataset to Spark executor working directory----')\n",
    "        # Copy dataset from project to local filesystem\n",
    "        hdfs.copy_to_local('TourData/mnist/MNIST')\n",
    "        hdfs.log('----------------------Download finished------------------------')\n",
    "\n",
    "    # Same preprocessing for the train and test splits\n",
    "    transform = transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.1307,), (0.3081,))\n",
    "    ])\n",
    "    train_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST(data_root, train=True, download=False, transform=transform),\n",
    "        batch_size=64, shuffle=True, **kwargs)\n",
    "    # The test metrics are sums over the whole set, so the order does not matter\n",
    "    test_loader = torch.utils.data.DataLoader(\n",
    "        datasets.MNIST(data_root, train=False, download=False, transform=transform),\n",
    "        batch_size=1000, shuffle=False, **kwargs)\n",
    "\n",
    "\n",
    "    model = Net().to(device)\n",
    "    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)\n",
    "\n",
    "    # Set SummaryWriter to point to the local directory containing the tensorboard logs\n",
    "    log_dir = tensorboard.logdir()\n",
    "    writer = SummaryWriter(log_dir=log_dir)\n",
    "    try:\n",
    "        # Single epoch\n",
    "        train(model, device, train_loader, optimizer, writer)\n",
    "    finally:\n",
    "        # Flush pending events and close the event file even if training fails\n",
    "        writer.close()\n",
    "    test(model, device, test_loader)\n",
    "\n",
    "\n",
    "    # os.path.join puts the checkpoint inside the log directory whether or not\n",
    "    # logdir() ends with a path separator\n",
    "    torch.save(model.state_dict(), os.path.join(log_dir, \"mnist_cnn.pt\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from hops import experiment\n",
    "# Simple experiment: run `wrapper` through hops' experiment.launch.\n",
    "# local_logdir=True to write to local FS and not HDFS\n",
    "# (look in the Experiments dataset for the contents after the experiment).\n",
    "experiment.launch(wrapper, local_logdir=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PySpark",
   "language": "",
   "name": "pysparkkernel"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "python",
    "version": 2
   },
   "mimetype": "text/x-python",
   "name": "pyspark",
   "pygments_lexer": "python2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}